{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Beyond Pytorch Native APIs",
      "provenance": [],
      "collapsed_sections": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GxDkdWhYSiLi",
        "colab_type": "text"
      },
      "source": [
        "# Beyond PyTorch Native APIs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ded3RflTZnZ8",
        "colab_type": "text"
      },
      "source": [
        "The XLA backend of PyTorch allows an end user to create functions whose implementation, in terms of the lower-level XLA HLO operations generated, is entirely controlled by the user's own Python code.\n",
        "\n",
        "The *xla_builder* module provides a slim wrapper around the *xla::XlaOp* objects documented within the XLA reference:\n",
        "\n",
        "- https://www.tensorflow.org/xla/operation_semantics\n",
        "- https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/client/xla_builder.h\n",
        "\n",
        "While this allows the user to create APIs whose semantics are not currently available in PyTorch, such APIs will only work when used with an XLA device.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iT6DRgBjeVbY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jaG6R8OWeZju",
        "colab_type": "text"
      },
      "source": [
        "### Only uncomment and run the cell below if you would like a nightly release"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sPJVqAKyml5W",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# VERSION = \"nightly\"  #@param [\"1.5\" , \"20200325\", \"nightly\"]\n",
        "# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n",
        "# !python pytorch-xla-env-setup.py --version $VERSION"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zQ2_OcQxMEI8",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "import torch_xla\n",
        "import torch_xla.core.xla_builder as xb\n",
        "import torch_xla.core.xla_op_registry as xor\n",
        "import torch_xla.core.xla_model as xm\n",
        "import torch_xla.distributed.xla_multiprocessing as xmp\n",
        "\n",
        "\n",
        "# Splits a rank 1 tensor into the scalar indices required by the XLA dynamic\n",
        "# slicing APIs.\n",
        "def _split_indices(index):\n",
        "  ishape = index.shape()\n",
        "  assert ishape.rank == 1\n",
        "  indices = []\n",
        "  for dim in range(0, ishape.sizes[0]):\n",
        "    indices.append(index.slice_in_dim(dim, dim + 1, 0).reshape([]))\n",
        "  return indices\n",
        "\n",
        "\n",
        "# This is the XLA lowering API. Here input and start_indices are Op objects of\n",
        "# the xla_builder module and can be manipulated with that module's API.\n",
        "#   https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_builder.py\n",
        "def _dynamic_slice_forward(input, start_indices, slice_sizes=None):\n",
        "  return input.dynamic_slice(_split_indices(start_indices), slice_sizes)\n",
        "\n",
        "\n",
        "# This is the XLA lowering API. Here grad_output, input and start_indices are Op\n",
        "# objects of the xla_builder module and can be manipulated with that module's API.\n",
        "#   https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_builder.py\n",
        "def _dynamic_slice_backward(grad_output, input, start_indices, slice_sizes=None):\n",
        "  return input.zeros_like().dynamic_update_slice(grad_output, _split_indices(start_indices))\n",
        "\n",
        "\n",
        "# For efficiency, it is better to register the XLA builder based operations at\n",
        "# global scope.\n",
        "DYNAMIC_SLICE_FORWARD = xor.register('DynamicSliceForward', _dynamic_slice_forward)\n",
        "DYNAMIC_SLICE_BACKWARD = xor.register('DynamicSliceBackward', _dynamic_slice_backward)\n",
        "\n",
        "\n",
        "# Standard PyTorch way to create a differentiable function.\n",
        "class DynamicSlice(torch.autograd.Function):\n",
        "  @staticmethod\n",
        "  def forward(ctx, input, start_indices, slice_sizes):\n",
        "    ctx.slice_sizes = slice_sizes\n",
        "    output = DYNAMIC_SLICE_FORWARD(input, start_indices, slice_sizes=slice_sizes)\n",
        "    ctx.save_for_backward(input, start_indices)\n",
        "    return output\n",
        "\n",
        "  @staticmethod\n",
        "  def backward(ctx, grad_output):\n",
        "    input, start_indices = ctx.saved_tensors\n",
        "    grad = DYNAMIC_SLICE_BACKWARD(grad_output, input, start_indices,\n",
        "                                  slice_sizes=ctx.slice_sizes)\n",
        "    # We need to return as many gradients as the forward() inputs, or None if\n",
        "    # such inputs are not differentiable.\n",
        "    return grad, None, None\n",
        "\n",
        "\n",
        "# Exposes the dynamic slice operation, which will support autograd differentiation.\n",
        "def dynamic_slice(input, start_indices, slice_sizes):\n",
        "  \"\"\"Slices an input tensor.\n",
        "\n",
        "  Args:\n",
        "    input (torch.Tensor): The input tensor to be sliced.\n",
        "    start_indices (torch.Tensor): The rank 1 tensor containing the start indices.\n",
        "      The size of the tensor (its dimension 0) must be the same as the rank of\n",
        "      the input tensor.\n",
        "    slice_sizes (list of int): The sizes of the slices. This is a list of Python\n",
        "      integers, whose length must be the same as the rank of the input tensor.\n",
        "  Returns:\n",
        "    The sliced input tensor.\n",
        "  \"\"\"\n",
        "  return DynamicSlice.apply(input, start_indices, slice_sizes)\n",
        "\n",
        "\n",
        "# Test implementation.\n",
        "device = xm.xla_device()\n",
        "\n",
        "x = torch.randn(6, 8, device=device, requires_grad=True)\n",
        "index = torch.tensor([2, 4], dtype=torch.int32, device=device)\n",
        "out = dynamic_slice(x, index, (2, 3))\n",
        "loss = out.pow(2).sum()\n",
        "loss.backward()\n",
        "print(x.grad.cpu())"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}
